{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "248097b8",
   "metadata": {},
   "source": [
    "# EasyEdit Example with **Wise**"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "753b8801",
   "metadata": {},
   "source": [
    "In this tutorial, we use `Wise` to edit `llama-3.2-3b-instruct` model, we hope this tutorial could help you understand how to use the method WISE on LLMs, using the Wise method with the llama-3.2-3b-instruct as an example."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1b0a1701",
   "metadata": {},
   "source": [
    "## Model Editing\n",
    "\n",
    "Deployed models may still make unpredictable errors. For example, Large Language Models (LLMs) notoriously hallucinate, perpetuate bias, and factually decay, so we should be able to adjust specific behaviors of pre-trained models.\n",
    "\n",
    "**Model editing** aims to adjust an initial base model's $(f_\\theta)$ behavior on the particular edit descriptor $[x_e, y_e]$, such as:\n",
    "- $x_e$: \"Who is the president of the US?\n",
    "- $y_e$: \"Joe Biden.\"\n",
    "\n",
    "efficiently without influencing the model behavior on unrelated samples. The ultimate goal is to create an edited model$(f_\\theta’)$."
   ]
  },
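  {
   "cell_type": "markdown",
   "id": "e7a1c3d4",
   "metadata": {},
   "source": [
    "The goal above can be sketched as a piecewise definition. This is an illustrative formalization rather than the exact notation of the WISE paper: the in-scope set $I(x_e, y_e)$, meaning $x_e$ together with its paraphrases, is notation assumed here for illustration.\n",
    "\n",
    "$$\n",
    "f_{\\theta'}(x) = \\begin{cases} y_e & \\text{if } x \\in I(x_e, y_e) \\\\\\\\ f_\\theta(x) & \\text{otherwise} \\end{cases}\n",
    "$$\n",
    "\n",
    "Under this reading, reliability asks that $f_{\\theta'}(x_e) = y_e$, generalization asks that the same holds for the other inputs in $I(x_e, y_e)$, and locality asks that outputs for inputs outside $I(x_e, y_e)$ stay unchanged."
   ]
  },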
  {
   "cell_type": "markdown",
   "id": "75b3b30a",
   "metadata": {},
   "source": [
    "## Method:WISE\n",
    "\n",
    "Paper: [WISE: Rethinking the Knowledge Memory for Lifelong Model Editing of Large Language Models?](http://arxiv.org/pdf/2405.14768)\n",
    "    \n",
    "**WISE**, is an approach for lifelong model editing of Large Language Models (LLMs). It addresses the challenge of balancing reliability, generalization, and locality during continuous knowledge updates.\n",
    "It provides an effective solution for continuous learning and knowledge updating in large language models through its innovative memory management and editing strategies."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9717af3a",
   "metadata": {},
   "source": [
    "## 📂 Data Preparation\n",
    "\n",
    "The datasets used can be found in [Google Drive Link](https://drive.google.com/file/d/1YtQvv4WvTa4rJyDYQR2J-uK8rnrt0kTA/view?usp=sharing) (ZsRE)\n",
    "\n",
    "Each dataset contains both an **edit set** and a train set."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2cf48075",
   "metadata": {},
   "source": [
    "## Prepare the runtime environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "6356ed23",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/mnt/8t/fangjizhan/EasyEdit\n",
      "data\t    examples  multimodal_edit.py   run_wise_editing.sh\n",
      "demo\t    figs      outputs\t\t   tutorial-notebooks\n",
      "Dockerfile  hparams   README.md\t\t   tutorial.pdf\n",
      "easyeditor  LICENSE   requirements.txt\n",
      "edit.py     logs      run_wise_editing.py\n"
     ]
    }
   ],
   "source": [
    "## Clone Repo\n",
    "#!git clone https://github.com/zjunlp/EasyEdit\n",
    "%cd EasyEdit\n",
    "!ls"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a104cd71",
   "metadata": {},
   "outputs": [],
   "source": [
    "!apt-get install python3.9\n",
    "!sudo update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.9 1\n",
    "!sudo update-alternatives --config python3\n",
    "!apt-get install python3-pip\n",
    "%pip install -r requirements.txt"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b039c94a",
   "metadata": {},
   "source": [
    "## Config Method  Parameters"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d553b513",
   "metadata": {},
   "source": [
    "```python\n",
    "alg_name: \"WISE\"\n",
    "model_name: \"./hugging_cache/llama-3.2-3b-instruct\"\n",
    "device: 0\n",
    "\n",
    "mask_ratio: 0.2\n",
    "edit_lr: 1.0\n",
    "n_iter: 70\n",
    "norm_constraint: 1.0\n",
    "act_margin: [5.0, 20.0, 10.0] # alpha, beta, gamma\n",
    "act_ratio: 0.88\n",
    "save_freq: 500\n",
    "merge_freq: 1000\n",
    "merge_alg: 'ties'\n",
    "objective_optimization: 'only_label'\n",
    "inner_params:\n",
    "- model.layers[27].mlp.down_proj.weight\n",
    "\n",
    "\n",
    "## alternative: WISE-Merge, WISE-Retrieve\n",
    "\n",
    "# for merge (if merge)\n",
    "densities: 0.53\n",
    "weights: 1.0\n",
    "\n",
    "# for retrieve (if retrieve, pls set to True)\n",
    "retrieve: True\n",
    "replay: False # True --> will replay the past editing instances: see https://arxiv.org/abs/2405.14768 Appendix B.3\n",
    "\n",
    "model_parallel: False\n",
    "\n",
    "# for save and load\n",
    "# save_path: \"./wise_checkpoint/wise.pt\"\n",
    "# load_path: \"./wise_checkpoint/wise.pt\"\n",
    "\n",
    "\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0e9aef0a",
   "metadata": {},
   "source": [
    "## Import models & Run\n",
    "\n",
    "### Edit llama2-7b-chat on ZsRE with WISE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "5c0a266f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/mnt/8t/xkw/EasyEdit\n"
     ]
    }
   ],
   "source": [
    "%cd .."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "2100450c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from easyeditor import BaseEditor\n",
    "from easyeditor import WISEHyperParams\n",
    "\n",
    "prompts = ['Question:What sport does Lionel Messi play? Answer:',\n",
    "                'Question:What role does Cristiano Ronaldo play in football? Answer:',\n",
    "                'Question:Which NBA team does Stephen Curry play for? Answer:']\n",
    "ground_truth = ['football', 'forward', 'Golden State Warriors']\n",
    "target_new = ['basketball', 'defender', 'New York Knicks']\n",
    "subject = ['Lionel Messi', 'Cristiano Ronaldo', 'Stephen Curry']\n",
    "\n",
    "loc_prompts = [\"nq question: ek veer ki ardaas veera meaning in english A Brother's Prayer... Veera\", \n",
    "               'nq question: where are the winter olympics going to be Seoul', \n",
    "               'nq question: physician who studies and treats diseases of the endocrine system endocrinologist']\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0ca0f1d",
   "metadata": {},
   "outputs": [],
   "source": [
    "hparams = WISEHyperParams.from_hparams('./hparams/WISE/llama3.2-3b.yaml')\n",
    "\n",
    "editor = BaseEditor.from_hparams(hparams)\n",
    "metrics, edited_model, _ = editor.edit(\n",
    "    prompts=prompts,\n",
    "    ground_truth=ground_truth,\n",
    "    target_new=target_new,\n",
    "    subject=subject,\n",
    "    loc_prompts=loc_prompts,\n",
    "    sequential_edit=True\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fe56c3a1",
   "metadata": {},
   "source": [
    "* edit_data: editing instance in edit set.\n",
    "* loc_data: used to provide xi in Equation 5, sampled from the train set.\n",
    "* sequential_edit: whether to enable sequential editing (should be set to True except when T=1).\n",
    "***"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c3db4502",
   "metadata": {},
   "source": [
    "### Reliability Test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc703696",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/json": {
       "ascii": false,
       "bar_format": null,
       "colour": null,
       "elapsed": 0.003385305404663086,
       "initial": 0,
       "n": 0,
       "ncols": null,
       "nrows": null,
       "postfix": null,
       "prefix": "Loading checkpoint shards",
       "rate": null,
       "total": 2,
       "unit": "it",
       "unit_divisor": 1000,
       "unit_scale": false
      },
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "dad447dc31554823baac2f4d0b37589b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from transformers import AutoTokenizer\n",
    "from transformers import LlamaForCausalLM\n",
    "\n",
    "device = 1\n",
    "model = LlamaForCausalLM.from_pretrained('./hugging_cache/llama-3.2-3b-instruct').to(f'cuda:{device}')\n",
    "tokenizer = AutoTokenizer.from_pretrained('./hugging_cache/llama-3.2-3b-instruct',trust_remote_code=True)\n",
    "tokenizer.pad_token_id = tokenizer.eos_token_id\n",
    "tokenizer.padding_side='left'\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2acf594",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Pre-Edit Outputs:   [\"<|eot_id|><|eot_id|><|eot_id|><|begin_of_text|>What university did Watts Humphrey attend? \\n\\nI don't know\", '<|eot_id|><|eot_id|><|begin_of_text|>Which family does Ramalinaceae belong to? \\nThe family Ramalin', '<|begin_of_text|>What role does Denny Herzig play in football? Denny Herzig is']\n",
      "Post-Edit Outputs:  [\"<|eot_id|><|eot_id|><|eot_id|><|begin_of_text|>What university did Watts Humphrey attend? \\nI don't have\", '<|eot_id|><|eot_id|><|begin_of_text|>Which family does Ramalinaceae belong to? \\nA) Brassic', '<|begin_of_text|>What role does Denny Herzig play in football? \\nDenny Herzig']\n"
     ]
    }
   ],
   "source": [
    "correct_prompts = ['Question:What sport does Lionel Messi play? Answer:',\n",
    "                'Question:What role does Cristiano Ronaldo play in football? Answer:',\n",
    "                'Question:Which NBA team does Stephen Curry play for? Answer:']\n",
    "# target_new = ['basketball', 'defender', 'New York Knicks']\n",
    "\n",
    "\n",
    "batch = tokenizer(correct_prompts, return_tensors='pt', padding=True, max_length=30)\n",
    "\n",
    "pre_edit_outputs = model.generate(\n",
    "    input_ids=batch['input_ids'].to(model.device),\n",
    "    attention_mask=batch['attention_mask'].to(model.device),\n",
    "    max_new_tokens=15\n",
    ")\n",
    "post_edit_outputs = edited_model.generate(\n",
    "    input_ids=batch['input_ids'].to(edited_model.device),\n",
    "    attention_mask=batch['attention_mask'].to(edited_model.device),\n",
    "    max_new_tokens=15\n",
    ")\n",
    "\n",
    "max_length = batch['input_ids'].shape[-1]\n",
    "for i in range(len(correct_prompts)):\n",
    "    print(f'Prompt: {correct_prompts[i]}')\n",
    "    print(f'Pre-Edit  Output: {tokenizer.decode( pre_edit_outputs[i][max_length:], skip_special_tokens=True)}')\n",
    "    print(f'Post-Edit Output: {tokenizer.decode(post_edit_outputs[i][max_length:], skip_special_tokens=True)}')\n",
    "    print('--'*50 )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "43528147",
   "metadata": {},
   "source": [
    "### Generalization test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4074b583",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Pre-Edit Outputs:   ['<|eot_id|><|begin_of_text|>What university did Watts Humphrey take part in? \\nI do not have access to a', '<|eot_id|><|eot_id|><|eot_id|><|eot_id|><|begin_of_text|>What family are Ramalinaceae??\\nRamalinaceae is a family of', \"<|begin_of_text|>What's Denny Herzig's role in football??\\nDenny Herzig is a former\"]\n",
      "Post-Edit Outputs:  ['<|eot_id|><|begin_of_text|>What university did Watts Humphrey take part in? \\nI could not find any information on', '<|eot_id|><|eot_id|><|eot_id|><|eot_id|><|begin_of_text|>What family are Ramalinaceae? Belong to which suborder?\\nRam', \"<|begin_of_text|>What's Denny Herzig's role in football? \\nDenny Herzig is the Director\"]\n"
     ]
    }
   ],
   "source": [
    "generation_prompts =['Question:What sports is Messi good at? Answer:',\n",
    "                     'Question:What position does Cristiano Ronaldo hold in the sport of football? Answer:',\n",
    "                     'Question:Which city does Stephen Curry currently working in? Answer:']\n",
    "\n",
    "batch = tokenizer(generation_prompts , return_tensors='pt', padding=True, max_length=30)\n",
    "\n",
    "pre_edit_outputs = model.generate(\n",
    "    input_ids=batch['input_ids'].to(model.device),\n",
    "    attention_mask=batch['attention_mask'].to(model.device),\n",
    "    max_new_tokens=15\n",
    ")\n",
    "post_edit_outputs = edited_model.generate(\n",
    "    input_ids=batch['input_ids'].to(edited_model.device),\n",
    "    attention_mask=batch['attention_mask'].to(edited_model.device),\n",
    "    max_new_tokens=15\n",
    ")\n",
    "\n",
    "max_length = batch['input_ids'].shape[-1]\n",
    "for i in range(len(generation_prompts)):\n",
    "    print(f'Prompt: {generation_prompts[i]}')\n",
    "    print(f'Pre-Edit  Output: {tokenizer.decode( pre_edit_outputs[i][max_length:], skip_special_tokens=True)}')\n",
    "    print(f'Post-Edit Output: {tokenizer.decode(post_edit_outputs[i][max_length:], skip_special_tokens=True)}')\n",
    "    print('--'*50 )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5d4c3779",
   "metadata": {},
   "source": [
    "### Locality test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f21404e8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Pre-Edit Outputs:   ['<|eot_id|><|begin_of_text|>nq question: who played desmond doss father in hacksaw ridge?\\nIn the 2016 film H', '<|begin_of_text|>nq question: types of skiing in the winter olympics 2018\\nThe 2018 Winter Olympics in', '<|eot_id|><|eot_id|><|eot_id|><|begin_of_text|>nq question: where does aarp fall on the political spectrum?\\nA) Liberal\\nB) ']\n",
      "Post-Edit Outputs:  ['<|eot_id|><|begin_of_text|>nq question: who played desmond doss father in hacksaw ridge (2016)\\nThe answer is Bill', '<|begin_of_text|>nq question: types of skiing in the winter olympics 2018\\nThe 2018 Winter Olympics in', '<|eot_id|><|eot_id|><|eot_id|><|begin_of_text|>nq question: where does aarp fall on the political spectrum?\\nAARP (American Association of Ret']\n"
     ]
    }
   ],
   "source": [
    "locality_prompts = ['Question:What sport does Kylian Mbappé play? Answer:',\n",
    "                    'Question:What role does Thierry Henry play in football? Answer:',\n",
    "                    'Question:Which NBA team does Jordan play for? Answer:']\n",
    "\n",
    "\n",
    "batch = tokenizer(locality_prompts, return_tensors='pt', padding=True, max_length=30)\n",
    "\n",
    "pre_edit_outputs = model.generate(\n",
    "    input_ids=batch['input_ids'].to(model.device),\n",
    "    attention_mask=batch['attention_mask'].to(model.device),\n",
    "    max_new_tokens=15\n",
    ")\n",
    "post_edit_outputs = edited_model.generate(\n",
    "    input_ids=batch['input_ids'].to(edited_model.device),\n",
    "    attention_mask=batch['attention_mask'].to(edited_model.device),\n",
    "    max_new_tokens=15\n",
    ")\n",
    "\n",
    "max_length = batch['input_ids'].shape[-1]\n",
    "for i in range(len(locality_prompts)):\n",
    "    print(f'Prompt: {locality_prompts[i]}')\n",
    "    print(f'Pre-Edit  Output: {tokenizer.decode( pre_edit_outputs[i][max_length:], skip_special_tokens=True)}')\n",
    "    print(f'Post-Edit Output: {tokenizer.decode(post_edit_outputs[i][max_length:], skip_special_tokens=True)}')\n",
    "    print('--'*50 )\n",
    "    "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "69cdeea5",
   "metadata": {},
   "source": [
    "## Citation\n",
    "If finding this work useful for your research, you can cite it as follows:"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f251d3f8",
   "metadata": {},
   "source": [
    "```bibtex\n",
    "@misc{wang2024wiserethinkingknowledgememory,\n",
    "      title={WISE: Rethinking the Knowledge Memory for Lifelong Model Editing of Large Language Models}, \n",
    "      author={Peng Wang and Zexi Li and Ningyu Zhang and Ziwen Xu and Yunzhi Yao and Yong Jiang and Pengjun Xie and Fei Huang and Huajun Chen},\n",
    "      year={2024},\n",
    "      eprint={2405.14768},\n",
    "      archivePrefix={arXiv},\n",
    "      primaryClass={cs.CL},\n",
    "      url={https://arxiv.org/abs/2405.14768}, \n",
    "}\n",
    "```"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "EasyEdit",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.20"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
